<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>YOLOv8 tinygrad WebGPU</title>
    <script type="module">
        import yolov8 from "./net.js"
        window.yolov8 = yolov8;
    </script>
    <style>
        /* Page chrome: centred text, no default margin, and no scrollbars. */
        body {
            text-align: center;
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 0;
            overflow: hidden;
        }

        /* Viewport-height flex box; position: relative makes it the containing
           block for its absolutely positioned children. */
        .video-container {
            position: relative;
            width: 100%;
            height: 100vh;
            margin: 0 auto;
            display: flex;
            align-items: center;
            justify-content: center;
        }

        /* Both elements are pinned to the container's top-left corner at full
           width, so they occupy the same area. */
        #video, #canvas {
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: auto;
        }

        /* Spinner: a white ring with a transparent bottom edge, rotated
           continuously by the "rotation" animation. */
        .loader {
            width: 48px;
            height: 48px;
            border: 5px solid #FFF;
            border-bottom-color: transparent;
            border-radius: 50%;
            display: inline-block;
            box-sizing: border-box;
            animation: rotation 1s linear infinite;
        }

        /* One full clockwise turn per animation cycle. */
        @keyframes rotation {
            0% {
                transform: rotate(0deg);
            }
            100% {
                transform: rotate(360deg);
            }
        }

        /* Keep the canvas see-through so whatever sits beneath it stays visible. */
        #canvas {
            background: transparent;
        }

        /* Small dark badge in the container's top-right corner. */
        #fps-meter {
            position: absolute;
            top: 20px;
            right: 20px;
            background-color: rgba(0, 0, 0, 0.7);
            color: white;
            padding: 10px;
            font-size: 18px;
            border-radius: 5px;
            z-index: 10;
        }

        /* Applies to any <h1> heading on the page. */
        h1 {
            margin-top: 20px;
        }

        /* Full-viewport dimmed overlay that centres its children in a column. */
        .loading-container {
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            position: fixed;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            background-color: rgba(0, 0, 0, 0.6);
            z-index: 10;
        }

        /* Caption shown inside the loading overlay. */
        .loading-text {
            font-size: 24px;
            color: white;
            margin-bottom: 20px;
        }
    </style>
</head>
<body>
    <!-- Page title: the document's single top-level heading. -->
    <h1>YOLOv8 tinygrad WebGPU</h1>
    <!-- Hidden by default; the page script switches it to display: block when no WebGPU device is available. -->
    <h2 id="wgpu-error" style="display: none; color: red;">Error: WebGPU is not supported in this browser</h2>
    <div class="video-container">
        <!-- Camera preview; its stream is attached by the page script. -->
        <video id="video" muted autoplay playsinline></video>
        <!-- Overlay the page script draws detection boxes onto. -->
        <canvas id="canvas"></canvas>
        <!-- Text is filled in by the page script's frame loop. -->
        <div id="fps-meter"></div>

        <!-- Overlay shown until the page script hides it. -->
        <div id="div-loading" class="loading-container">
            <p class="loading-text">Loading model</p>
            <span class="loader"></span>
        </div>
    </div>
    <script>
        // ---- Model state ----
        // Side length, in pixels, of the square canvas frames are drawn onto before inference.
        const modelInputSize = 416;
        // The loaded network; stays null until detectObjectsOnFrame() initialises it.
        let net = null;

        // ---- FPS bookkeeping (maintained by processFrame) ----
        let lastCalledTime;
        let fps = 0;
        let accumFps = 0;
        let frameCounter = 0;

        // ---- Elements declared in the page markup ----
        const video = document.getElementById('video');
        const canvas = document.getElementById('canvas');
        const fpsMeter = document.getElementById('fps-meter');
        const loadingContainer = document.getElementById('div-loading');
        const wgpuError = document.getElementById('wgpu-error');
        const context = canvas.getContext('2d');

        // ---- Detached square canvas used as the model's input surface ----
        const offscreenCanvas = document.createElement('canvas');
        offscreenCanvas.width = modelInputSize;
        offscreenCanvas.height = modelInputSize;
        const offscreenContext = offscreenCanvas.getContext('2d');


        // Ask for a video-only stream (preferring the rear-facing camera) and attach it to
        // the <video> element. Failures are logged rather than left as unhandled rejections.
        if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
            navigator.mediaDevices.getUserMedia({ audio: false,  video: { facingMode: { ideal: "environment" }}})
                .then(function (stream) {
                    video.srcObject = stream;
                    video.onloadedmetadata = function() {
                        // Size the overlay's drawing buffer to the video's on-page size.
                        canvas.width = video.clientWidth;
                        canvas.height = video.clientHeight;
                    }
                })
                .catch(function (err) {
                    // The request was rejected (for example permission denied or no camera found).
                    console.error("Unable to access the camera:", err);
                });
        } else {
            console.error("navigator.mediaDevices.getUserMedia is not available in this browser");
        }

        /**
         * One iteration of the render loop: draws the current video frame onto the square
         * offscreen canvas, runs detection on it, paints the resulting boxes, and schedules
         * itself for the next animation frame.
         *
         * While the video reports a zero width or height it only reschedules itself.
         */
        async function processFrame() {
            if (video.videoWidth == 0 || video.videoHeight == 0) {
                requestAnimationFrame(processFrame);
                return;
            }

            if (!lastCalledTime) {
                // First processed frame: there is no previous timestamp to measure against.
                lastCalledTime = performance.now();
                fps = 0;
            } else {
                const now = performance.now();
                // Seconds since the previous frame. Declared with const so it is no longer
                // created as an implicit global.
                const delta = (now - lastCalledTime)/1000.0;
                lastCalledTime = now;
                accumFps += 1/delta;

                // Once the counter has reached 10, display the mean of the accumulated
                // per-frame rates and start a new averaging window.
                if (frameCounter++ >= 10) {
                    fps = accumFps/frameCounter;
                    frameCounter = 0;
                    accumFps = 0;
                    fpsMeter.innerText = `FPS: ${fps.toFixed(1)}`
                }
            }

            // Fit the frame inside the square model input while keeping its aspect ratio:
            // the longer side fills modelInputSize and the shorter side is scaled to match.
            const videoAspectRatio = video.videoWidth / video.videoHeight;
            let targetWidth, targetHeight;

            if (videoAspectRatio > 1) {
                targetWidth = modelInputSize;
                targetHeight = modelInputSize / videoAspectRatio;
            } else {
                targetHeight = modelInputSize;
                targetWidth = modelInputSize * videoAspectRatio;
            }

            // Centre the scaled frame; the remaining area of the square stays cleared.
            const offsetX = (modelInputSize - targetWidth) / 2;
            const offsetY = (modelInputSize - targetHeight) / 2;
            offscreenContext.clearRect(0, 0, modelInputSize, modelInputSize);
            offscreenContext.drawImage(video, offsetX, offsetY, targetWidth, targetHeight);
            const boxes = await detectObjectsOnFrame(offscreenContext);
            drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY);
            requestAnimationFrame(processFrame);
        }

        // Kick off the render loop.
        requestAnimationFrame(processFrame);

        /**
         * Clears the first <canvas> in the document and paints one outlined, labelled
         * rectangle per detection.
         *
         * @param offscreenCanvas Accepted for the caller's convenience; not read here.
         * @param boxes           Array of detections; the first five entries of each are
         *                        x1, y1, x2, y2 and the label text.
         * @param targetWidth     Width of the video area inside the model input square.
         * @param targetHeight    Height of the video area inside the model input square.
         * @param offsetX         Left padding of the video area inside the square.
         * @param offsetY         Top padding of the video area inside the square.
         */
        function drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY) {
            const painter = document.querySelector("canvas").getContext("2d");
            painter.clearRect(0, 0, canvas.width, canvas.height);
            painter.lineWidth = 3;
            painter.font = "30px serif";

            // Factors that stretch the video area of the model input up to the overlay size.
            const xFactor = canvas.width / targetWidth;
            const yFactor = canvas.height / targetHeight;

            for (const [x1, y1, x2, y2, label] of boxes) {
                const paint = classColors[yolo_classes.indexOf(label)];
                painter.strokeStyle = paint;
                painter.fillStyle = paint;

                // Remove the padding, then scale into overlay coordinates.
                const boxLeft = (x1 - offsetX) * xFactor;
                const boxTop = (y1 - offsetY) * yFactor;
                const boxRight = (x2 - offsetX) * xFactor;
                const boxBottom = (y2 - offsetY) * yFactor;
                painter.strokeRect(boxLeft, boxTop, boxRight - boxLeft, boxBottom - boxTop);

                // Filled tag sitting on top of the box, with the label in white.
                const labelWidth = painter.measureText(label).width;
                painter.fillRect(boxLeft, boxTop - 25, labelWidth + 10, 25);
                painter.fillStyle = "#FFFFFF";
                painter.fillText(label, boxLeft + 5, boxTop - 7);
            }
        }

        /**
         * Runs the detector on the current contents of the given 2D context and returns
         * whatever processOutput() produces for the network's first output.
         *
         * On the first call it obtains a GPU device and loads the network; later calls
         * reuse the cached `net`. Timing for each stage is logged to the console.
         *
         * @param offscreenContext 2D context holding the frame to analyse.
         * @throws {Error} when no WebGPU device could be obtained. The error banner is
         *                 shown first, and the model is never loaded with an invalid device.
         */
        async function detectObjectsOnFrame(offscreenContext) {
            if (!net) {
                const device = await getDevice();
                if (!device) {
                    wgpuError.style.display = "block";
                    loadingContainer.style.display = "none";
                    // Without a device there is nothing to run the model on; stop here
                    // instead of handing a falsy device to the loader.
                    throw new Error("WebGPU is not supported in this browser");
                }
                net = await yolov8.load(device, "./net.safetensors");
                loadingContainer.style.display = "none";
            }
            let start = performance.now();
            const [input,img_width,img_height] = await prepareInput(offscreenContext);
            console.log("Preprocess took: " + (performance.now() - start) + " ms");
            start = performance.now();
            const output = await net(new Float32Array(input));
            console.log("Inference took: " + (performance.now() - start) + " ms");
            start = performance.now();
            let out = processOutput(output[0],img_width,img_height);
            console.log("Postprocess took: " + (performance.now() - start) + " ms");
            return out;
        }

        /**
         * Reads the modelInputSize x modelInputSize pixel square from the given context and
         * converts it to planar form: every red value, then every green, then every blue,
         * each divided by 255. The alpha channel is dropped.
         *
         * @param offscreenContext 2D context to read pixels from.
         * @returns Promise of [values, width, height], where width and height are both
         *          modelInputSize.
         */
        async function prepareInput(offscreenContext) {
            const side = modelInputSize;
            const rgba = offscreenContext.getImageData(0, 0, side, side).data;
            const pixelCount = rgba.length / 4;

            // One flat array with three consecutive colour planes.
            const planar = new Array(3 * pixelCount);
            for (let p = 0; p < pixelCount; p++) {
                const base = 4 * p;
                planar[p] = rgba[base] / 255.0;
                planar[pixelCount + p] = rgba[base + 1] / 255.0;
                planar[2 * pixelCount + p] = rgba[base + 2] / 255.0;
            }
            return [planar, side, side];
        }

        /**
         * Requests a WebGPU device.
         *
         * @returns the device, or false when `navigator.gpu` is missing or no adapter
         *          could be obtained.
         */
        const getDevice = async () => {
            if (!navigator.gpu) return false;
            const adapter = await navigator.gpu.requestAdapter();
            // requestAdapter() can resolve to null; treat that like missing WebGPU support
            // instead of failing on a null dereference.
            if (!adapter) return false;
            return await adapter.requestDevice();
        };
        
        /**
         * Turns the flat network output into a list of detections and removes overlapping
         * duplicates.
         *
         * The output is read channel by channel: channel c of prediction i lives at
         * output[c * numPredictions + i]. Channels 0-3 are read as centre x, centre y,
         * width and height; channels 4-83 as the 80 per-class scores.
         *
         * @param output     Flat, indexable network output.
         * @param img_width  Width to scale x coordinates to.
         * @param img_height Height to scale y coordinates to.
         * @returns Array of [x1, y1, x2, y2, label, prob], highest score first.
         */
        function processOutput(output, img_width, img_height) {
            // NOTE(review): 21 * (size/32)^2 presumably counts the cells of three grids at
            // strides 8, 16 and 32 — confirm against the exported model.
            const numPredictions = Math.pow(modelInputSize/32, 2) * 21;
            let boxes = [];
            for (let index = 0; index < numPredictions; index++) {
                // Best-scoring class for this prediction. A plain loop avoids allocating
                // temporary arrays for every prediction on every frame.
                let class_id = 0, prob = 0;
                for (let col = 0; col < 80; col++) {
                    const score = output[numPredictions*(col+4)+index];
                    if (score > prob) {
                        prob = score;
                        class_id = col;
                    }
                }

                // Discard low-confidence predictions.
                if (prob < 0.25) continue;
                const label = yolo_classes[class_id];
                const xc = output[index];
                const yc = output[numPredictions+index];
                const w = output[2*numPredictions+index];
                const h = output[3*numPredictions+index];
                // Convert centre/size to corner coordinates, rescaled to the requested size.
                const x1 = (xc-w/2)/modelInputSize*img_width;
                const y1 = (yc-h/2)/modelInputSize*img_height;
                const x2 = (xc+w/2)/modelInputSize*img_width;
                const y2 = (yc+h/2)/modelInputSize*img_height;
                boxes.push([x1,y1,x2,y2,label,prob]);
            }

            // Non-maximum suppression: repeatedly keep the highest-scoring box and drop
            // every remaining box whose IoU with it is not below 0.7.
            boxes.sort((box1,box2) => box2[5]-box1[5]);
            const result = [];
            while (boxes.length > 0) {
                const best = boxes[0];
                result.push(best);
                // Remove the kept box explicitly so the loop always makes progress,
                // whatever iou() reports for a box compared with itself.
                boxes = boxes.slice(1).filter(box => iou(best,box) < 0.7);
            }
            return result;
        }

        /**
         * Intersection-over-union of two boxes: the value of intersection() divided by
         * the value of union() for the same pair.
         */
        function iou(box1,box2) {
            const overlap = intersection(box1,box2);
            const combined = union(box1,box2);
            return overlap/combined;
        }

        /**
         * Area covered by either box: the two box areas added together, minus the value
         * of intersection() for the pair. Each box starts with x1, y1, x2, y2.
         */
        function union(box1,box2) {
            const area = ([left, top, right, bottom]) => (right-left)*(bottom-top);
            return area(box1) + area(box2) - intersection(box1,box2);
        }

        /**
         * Area of the overlap between two boxes, each given as an array starting with
         * x1, y1, x2, y2.
         *
         * @returns the overlapping area, or 0 when the boxes do not overlap.
         */
        function intersection(box1,box2) {
            const [box1_x1,box1_y1,box1_x2,box1_y2] = box1;
            const [box2_x1,box2_y1,box2_x2,box2_y2] = box2;
            const x1 = Math.max(box1_x1,box2_x1);
            const y1 = Math.max(box1_y1,box2_y1);
            const x2 = Math.min(box1_x2,box2_x2);
            const y2 = Math.min(box1_y2,box2_y2);
            // Clamp each extent at zero: for boxes that are apart on both axes the two
            // negative extents would otherwise multiply into a positive "overlap".
            const overlapWidth = Math.max(0, x2-x1);
            const overlapHeight = Math.max(0, y2-y1);
            return overlapWidth*overlapHeight;
        }

        // Class labels indexed by class id (80 entries). processOutput() maps the
        // best-scoring class index to its label through this table, and drawBoxes()
        // maps the label back to an index to pick a colour.
        // NOTE(review): the names look like the COCO label set — confirm the order matches the model.
        const yolo_classes = [
            'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
            'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
            'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
            'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
            'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
            'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
            'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
            'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
        ];

        /**
         * Builds a palette of fully saturated colours whose hues are spread evenly
         * around the colour wheel, starting at hue 0.
         *
         * @param numColors Number of colours wanted.
         * @returns Array of `hsl(...)` strings; empty when numColors is 0 or negative.
         */
        function generateColors(numColors) {
            const colors = [];
            // Derive each hue from its index instead of accumulating a fractional step,
            // so rounding cannot change how many colours are produced and a non-positive
            // count simply yields an empty palette.
            for (let k = 0; k < numColors; k++) {
                colors.push(`hsl(${(k * 360) / numColors}, 100%, 50%)`);
            }
            return colors;
        }

        // Palette sized to the label table; drawBoxes() indexes it with a label's
        // position in yolo_classes.
        const classColors = generateColors(yolo_classes.length);
    </script>
</body>
</html>
